# -*- coding: utf-8 -*-
"""
Created on Thu Jun 25 12:48:36 2020

@author: wangrong
"""

from torchvision import transforms
from train import Classification
from PIL import Image
import models.regnet
from models import *

# Two trained classifiers are loaded from their saved checkpoints.
# ``clspre`` (VGG16) is the one used for the predictions later in this script.
clspre = Classification(
    model_name='./history/vgg0.01/model300.pth',
    train_net=VGG('VGG16'),
)
# NOTE(review): ``clspre1`` (SENet18) is only printed for inspection here and
# is not used for any prediction in this file's visible code — confirm intent.
clspre1 = Classification(
    model_name='./history/senet0.01/model300.pth',
    train_net=SENet18(),
)
print(clspre1.net)

# Preprocessing applied before prediction: resize to 32x32, then convert the
# PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])


# Test images: classify the first image ('0.jpg') in each verification folder
# with the VGG16 classifier ``clspre`` and print the predicted kiwi variety.

# Predicted class index -> variety name. Any index not listed here falls back
# to DEFAULT_VARIETY, exactly like the original if/elif/else chain.
KIWI_VARIETIES = {1: '瑞玉', 2: '徐香'}
DEFAULT_VARIETY = '翠香'

for address in ('./verify_image/cuixiang/',
                './verify_image/xuxiang/',
                './verify_image/ruiyu/'):
    # Open the image in a ``with`` block so its file handle is released once
    # the prediction has been made (the original never closed it).
    with Image.open(address + '0.jpg') as img:
        result = clspre.predict(img, transform).cpu().numpy()
    # NOTE(review): assumes ``predict`` returns a 1-D array of class indices,
    # so ``result[0]`` is a hashable numpy scalar — confirm against train.py.
    variety = KIWI_VARIETIES.get(result[0], DEFAULT_VARIETY)
    print('The kiwi from ' + address + ' is ' + variety)

